import os
import h5py
import util.landscape.util as util
import torch.distributed as dist
from util.trainer.model import load_model
from util.landscape.projection import project_1D_coordinate,  project_2D_coordinate
import numpy as np
from util.landscape.projection import cal_cos
from sklearn.decomposition import PCA
import torch
from sklearn.preprocessing import StandardScaler

def _unit(v, what):
    """Return ``v / ||v||``.

    Raises:
        ValueError: if ``v`` has zero norm (the division would silently fill
            the direction with NaNs).
    """
    norm = v.norm()
    if norm == 0:
        raise ValueError(f"cannot normalize {what}: it has zero norm")
    return v / norm


def _reject(v, unit):
    """Return the component of ``v`` orthogonal to the unit vector ``unit``."""
    # Projecting onto a *unit* vector makes the coefficient simply <v, unit>,
    # so the result does not depend on whether a projection helper divides by
    # ||d|| or ||d||**2.
    return v - (v * unit).sum() * unit


def setup_direction(v1, v2, basis='orthonormal', scale1=1, scale2=1):
    """Build the two plotting directions ``(dx, dy)`` from raw vectors ``v1``, ``v2``.

    Args:
        v1: first raw direction (tensor); becomes ``dx`` up to normalization/scaling.
        v2: second raw direction (tensor); becomes ``dy`` up to
            orthogonalization/normalization/scaling.
        basis: how to post-process the vectors:
            - ``'orthonormal'``: ``dx = v1/||v1||``; ``dy`` is the unit-length
              component of ``v2`` orthogonal to ``v1``.
            - ``'orthogonal'``: ``dx = v1`` unchanged; ``dy`` is the component of
              ``v2`` orthogonal to ``v1`` (not normalized).
            - ``'orthoscale'``: as ``'orthonormal'``, then ``dx *= scale1`` and
              ``dy *= scale2``.
            - ``'scale'``: ``v1`` and ``v2`` normalized independently (no
              orthogonalization), then scaled by ``scale1`` / ``scale2``.
            - anything else: ``v1`` and ``v2`` are returned unchanged.
        scale1: length of ``dx`` for the ``'orthoscale'`` and ``'scale'`` bases.
        scale2: length of ``dy`` for the ``'orthoscale'`` and ``'scale'`` bases.

    Returns:
        Tuple ``(dx, dy)``.

    Raises:
        ValueError: if a vector that must be normalized has zero norm (e.g. ``v1``
            is zero, or ``v2`` is parallel to ``v1`` for the orthonormal bases).
    """
    if basis == 'orthonormal':
        dx = _unit(v1, "v1")
        dy = _unit(_reject(v2, dx), "the component of v2 orthogonal to v1")
    elif basis == 'orthogonal':
        dx = v1
        dy = _reject(v2, _unit(v1, "v1"))
    elif basis == 'orthoscale':
        unit_x = _unit(v1, "v1")
        dx = unit_x * scale1
        dy = _unit(_reject(v2, unit_x), "the component of v2 orthogonal to v1") * scale2
    elif basis == 'scale':
        dx = _unit(v1, "v1") * scale1
        dy = _unit(v2, "v2") * scale2
    else:
        dx, dy = v1, v2
    return dx, dy

# Trajectory whose final checkpoint serves as the origin for the "optm_pca"
# method. This id was hard-coded in the original experiment; a run can
# override it by setting ``args.pca_origin_traj``.
_PCA_ORIGIN_TRAJECTORY = 360702


def _is_main_process(rank):
    """Return True for the process that prints and writes files (rank 0, or ``rank is None``)."""
    return rank == 0 or rank is None


def _barrier():
    """Call ``dist.barrier()`` only when a default process group is initialized.

    ``dist.barrier()`` raises when no process group exists, which is the
    situation the ``rank is None`` checks in this module appear to cater for.
    """
    if dist.is_available() and dist.is_initialized():
        dist.barrier()


def _checkpoint_dirs(traj):
    """Return ``[(root, files), ...]`` for every directory under ``../checkpoints/<traj>``.

    ``files`` is sorted with ``util.extract_number`` (presumably by the epoch
    number embedded in the file name).

    Raises:
        FileNotFoundError: if nothing is found (``os.walk`` yields nothing for a
            missing path, which previously surfaced later as a NameError or as a
            silently skipped trajectory).
    """
    top = f"../checkpoints/{traj}"
    found = [(root, sorted(files, key=util.extract_number)) for root, _, files in os.walk(top)]
    if not found:
        raise FileNotFoundError(f"no checkpoint directory found at {top}")
    return found


def _trajs_directions(args):
    """Directions from the shared start of two trajectories to each trajectory's final checkpoint.

    Returns:
        ``(origin, [dx, dy])``. ``origin`` holds the weights of the final checkpoint
        of ``args.trajectories[0]``. ``dx``/``dy`` come from ``setup_direction`` applied
        to (final - initial) of each trajectory with ``args.basis``.
    """
    for root1, files1 in _checkpoint_dirs(args.trajectories[0]):
        # Load the final checkpoint once and reuse it for both origin and vec1.
        model_1 = load_model(os.path.join(root1, files1[-1]), args)
        origin = util.get_weights(model_1)
        init = util.get_weights(load_model(os.path.join(root1, files1[0]), args))
        vec1 = util.get_diff_weights(init, model_1.parameters())
    for root2, files2 in _checkpoint_dirs(args.trajectories[1]):
        # Compare the second trajectory's initial weights with the first one's
        # (what happens on mismatch is up to util.check_params_equal).
        util.check_params_equal(init, util.get_weights(load_model(os.path.join(root2, files2[0]), args)))
        model_2 = load_model(os.path.join(root2, files2[-1]), args)
        vec2 = util.get_diff_weights(init, model_2.parameters())
    dx, dy = setup_direction(vec1, vec2, args.basis)
    return origin, [dx, dy]


def _optm_pca_directions(args, local_rank):
    """Directions given by the first two principal components of the trajectories' final weights.

    Each PCA sample is the flattened weight vector of a trajectory's final checkpoint.
    sklearn's PCA centers the samples on their mean. The origin is the final checkpoint
    of ``args.pca_origin_traj`` (default ``_PCA_ORIGIN_TRAJECTORY``), which must be one
    of ``args.trajectories``.

    Raises:
        ValueError: if the origin trajectory is not among ``args.trajectories``.
    """
    origin_traj = getattr(args, "pca_origin_traj", None)
    if origin_traj is None:
        origin_traj = _PCA_ORIGIN_TRAJECTORY
    origin = None
    samples = []
    for traj in args.trajectories:
        for root, files in _checkpoint_dirs(traj):
            model = load_model(os.path.join(root, files[-1]), args)
            # Compare as strings so the id matches whether trajectories are ints or strs.
            if str(traj) == str(origin_traj):
                origin = util.get_weights(model)
            for param in model.parameters():
                param.requires_grad = False
            samples.append(util.tensorlist_to_tensor(util.get_weights(model)).cpu().numpy())
    if origin is None:
        raise ValueError(
            f"optm_pca: origin trajectory {origin_traj} is not in args.trajectories {args.trajectories}"
        )
    pca = PCA(n_components=2)
    pca.fit(samples)
    device = f"cuda:{local_rank}"
    dx = torch.from_numpy(pca.components_[0]).to(device)
    dy = torch.from_numpy(pca.components_[1]).to(device)
    return origin, [dx, dy]


def _fpp_directions(args, local_rank):
    """Per-frame directions built from the checkpoint steps after ``args.current_epoch``.

    Loads checkpoints ``current_epoch .. current_epoch + fpp_scope`` of
    ``args.trajectories[0]``. It stacks the ``fpp_scope`` consecutive step vectors
    and takes the top two singular directions of their Gram matrix. It then rotates
    that in-plane basis so that ``first`` is the projection of the first step onto
    the plane and ``second`` is perpendicular to it within the plane. For frames
    after ``args.start_epoch``, it orients ``second`` consistently with the previous
    frame's file.

    Returns:
        ``(origin, [second, first])``. ``origin`` holds the weights at checkpoint
        ``current_epoch``; both directions are unit length.

    Raises:
        ValueError: if ``args.fpp_scope < 2`` (two principal step directions are needed).
    """
    if args.fpp_scope < 2:
        raise ValueError(f"fpp_scope must be at least 2, got {args.fpp_scope}")
    start = args.current_epoch
    for root, files in _checkpoint_dirs(args.trajectories[0]):
        # Slide over consecutive checkpoints so each one is loaded only once.
        prev_model = load_model(os.path.join(root, files[start]), args)
        origin = util.get_weights(prev_model)
        step_dir = []
        for i in range(start, start + args.fpp_scope):
            next_model = load_model(os.path.join(root, files[i + 1]), args)
            step_dir.append(util.get_diff_weights(prev_model.parameters(), next_model.parameters()))
            prev_model = next_model
        # The frame's movement is the first step of the window (start -> start + 1).
        mov_dir = step_dir[0]

        steps = torch.stack(step_dir)
        gram = steps @ steps.T
        u, _, _ = torch.linalg.svd(gram)
        # Singular values come in descending order: dy follows the dominant component.
        dx = steps.T @ u[:, 1]
        dy = steps.T @ u[:, 0]
        dx = dx / dx.norm()
        dy = dy / dy.norm()

        # presumably the coordinates of mov_dir along dy and dx
        coordy, coordx = project_2D_coordinate(mov_dir, dy, dx, args.basis)
        first = coordy * dy + coordx * dx
        second = coordx * dy - coordy * dx  # in-plane, perpendicular to `first`

        if args.current_epoch != args.start_epoch:
            # Make sure the previous frame's file is written before any rank reads it.
            _barrier()
            prev_frame = os.path.join(
                args.path_to_orgn_dir_file, args.id, "frame_" + str(args.current_epoch - 1) + ".h5"
            )
            with h5py.File(prev_frame, "r") as f:
                last_directions = util.read_tensorlist(f, "directions", local_rank)
            # last_directions is [second, first] of the previous frame. Flip `second` so that
            # it agrees with the previous `second` exactly when `first` agrees with the
            # previous `first` (cosine signs).
            if (cal_cos(first, last_directions[1]) > 0) ^ (cal_cos(second, last_directions[0]) > 0):
                second = -second

        first = first / first.norm()
        second = second / second.norm()
        directions = [second, first]
    return origin, directions


def get_origin_and_directions(args, rank, local_rank):
    """Return the plotting origin and the two plotting directions for this run.

    The origin is a tensor list, and each direction is a single tensor.

    The result is cached in an HDF5 file:
    ``<path_to_orgn_dir_file>/<id>.h5``, or
    ``<path_to_orgn_dir_file>/<id>/frame_<current_epoch>.h5`` when
    ``args.animation`` is set. If the file exists, origin and directions are
    read from it. Otherwise they are computed with the method named by
    ``args.dir_mthd`` ("trajs", "optm_pca" or "fpp") and written by the main
    process (rank 0, or ``rank is None``).

    Args:
        args: run configuration (fields read include animation,
            path_to_orgn_dir_file, id, current_epoch, dir_mthd, trajectories,
            basis, and for "fpp" fpp_scope/start_epoch).
        rank: global rank; ``None`` is treated as the main process.
        local_rank: passed to ``util.read_tensorlist`` and used as the CUDA
            device index for PCA directions.

    Returns:
        ``(origin, directions)`` where ``directions`` is a two-element list.

    Raises:
        ValueError: unknown ``args.dir_mthd`` or invalid method settings.
        FileNotFoundError: no checkpoints found for a requested trajectory.
    """
    if args.animation:
        orgn_dir_file = os.path.join(
            args.path_to_orgn_dir_file, args.id, "frame_" + str(args.current_epoch) + ".h5"
        )
    else:
        orgn_dir_file = os.path.join(args.path_to_orgn_dir_file, args.id + ".h5")

    if os.path.exists(orgn_dir_file):
        if _is_main_process(rank):
            print(f"These directions are derived via same method as {args.id}.")
        # ensure all rank can read origin and directions file
        # NOTE(review): each rank checks for the file independently and only the main
        # process writes it, so a rank that checks after the main process has written
        # takes this branch (and this barrier) alone; confirm callers avoid that race.
        _barrier()
        with h5py.File(orgn_dir_file, "r") as f:
            origin = util.read_tensorlist(f, "origin", local_rank)
            directions = util.read_tensorlist(f, "directions", local_rank)
        return origin, directions

    if args.dir_mthd == "trajs":
        origin, directions = _trajs_directions(args)
        if _is_main_process(rank):
            print(f"These directions are from common start to corresponding ends of two trajectories {args.trajectories[0]} and {args.trajectories[1]}")
    elif args.dir_mthd == "optm_pca":
        origin, directions = _optm_pca_directions(args, local_rank)
        if _is_main_process(rank):
            print(f"These directions are derived from first two components of pca of vectors from origin to optima of trajectories.")
    elif args.dir_mthd == "fpp":
        origin, directions = _fpp_directions(args, local_rank)
    else:
        raise ValueError(
            f"unknown dir_mthd {args.dir_mthd!r}; expected 'trajs', 'optm_pca' or 'fpp'"
        )

    if _is_main_process(rank):
        if args.animation:
            os.makedirs(os.path.join(args.path_to_orgn_dir_file, args.id), exist_ok=True)
        with h5py.File(orgn_dir_file, "w") as f:
            util.write_tensorlist(f, 'origin', origin)
            util.write_tensorlist(f, 'directions', directions)
    return origin, directions
# dx = util.vec_to_tensorlist(dx, origin)
            # dy = util.vec_to_tensorlist(dy, origin)
            # util.normalize_directions_for_weights(dx, origin)
            # util.normalize_directions_for_weights(dy, origin)
            # dx = util.tensorlist_to_tensor(dx)
            # dy = util.tensorlist_to_tensor(dy)
            # print(cal_cos(dx,dy))

        # if args.dir_mthd=="optm_traj": 
        #     for root1, _, files1 in os.walk(f"../checkpoints/{args.trajectories[0]}"):
        #         files1 = sorted(files1, key=util.extract_number)
        #         origin = util.get_weights(load_model(os.path.join(root1, files1.pop(0)), args))
        #         model_1 = load_model(os.path.join(root1, files1[-1]), args)
        #         vec1 = util.get_diff_weights(origin, model_1.parameters())
        #     for root2, _, files2 in os.walk(f"../checkpoints/{args.trajectories[1]}"):
        #         files2 = sorted(files2, key=util.extract_number)
        #         util.check_params_equal(origin, util.get_weights(load_model(os.path.join(root2, files2.pop(0)), args)))
        #         model_2 = load_model(os.path.join(root2, files2[-1]), args)
        #         vec2 = util.get_diff_weights(origin, model_2.parameters())
        #     dx, dy = setup_direction(vec1, vec2, args.basis)
        #     directions = [dx, dy]
        #     if rank==0 or (rank is None):
        #         print(f"These directions are from common start to corresponding ends of two trajectories {args.trajectories[0]} and {args.trajectories[1]}")



        # if args.dir_mthd=="optm_pca":
        #     D_matrix = [] 
        #     origin = None
        #     for traj in args.trajectories:
        #         for root, _, files in os.walk(f"../checkpoints/{traj}"):
        #             files = sorted(files, key=util.extract_number)
        #             if origin is None:
        #                 origin = util.get_weights(load_model(os.path.join(root, files.pop(0)), args))
        #             else:
        #                 util.check_params_equal(origin, util.get_weights(load_model(os.path.join(root, files.pop(0)), args)))
        #             model = load_model(os.path.join(root, files[-1]), args)
        #             for param in model.parameters():
        #                 param.requires_grad = False
        #             vec = util.get_diff_weights(origin, model.parameters())
        #             D_matrix.append(vec.cpu().numpy())
        #     pca = PCA(n_components=2)
        #     pca.fit(np.array(D_matrix))
        #     dx = torch.from_numpy(pca.components_[0]).to(f"cuda:{local_rank}")
        #     dy = torch.from_numpy(pca.components_[1]).to(f"cuda:{local_rank}")
        #     origin = torch.from_numpy(pca.mean_).to(f"cuda:{local_rank}")
        #     init=None
        #     for i in range(len(D_matrix)):
        #         if init is None:
        #             init = torch.tensor(D_matrix[i]).to(f"cuda:{local_rank}")-origin 
        #         else:
        #             init += torch.tensor(D_matrix[i]).to(f"cuda:{local_rank}")-origin 
        #     print(init)
        #     origin = util.vec_to_tensorlist(origin, model.parameters())
        #     directions = [dx, dy]
            # vec = None
            # for traj in args.trajectories:
            #     for root, _, files in os.walk(f"../checkpoints/{traj}"):
            #         files = sorted(files, key=util.extract_number)
            #         model = load_model(os.path.join(root, files[-1]), args)
            #         for param in model.parameters():
            #             param.requires_grad = False
            #         if vec is None:
            #             vec = util.get_diff_weights(origin, model.parameters())
            #         else:
            #             vec += util.get_diff_weights(origin, model.parameters())
            # print(vec)